{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "基本原理和离散动作是一样的,连续动作的概率使用高斯密度函数计算即可."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAcy0lEQVR4nO3dfXCTZbo/8G/SNOlrUgo0odJKzwGBHiirBUrWM+M5S6Vq1/UFz7gOqx2XoyMGBsRh1u5KnXV3pvxwZl3dVXRmZ4Xzh3YHZ6srC7o9Rcs6hAKFLqVAV49oeyhJhW5e6EuaJtfvD+lzDFRM2qZ3Qr+fmWeGPPeV5noK+fLkfl6iExEBEdEk06tugIimJoYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpoSx8XnnlFcyZMwdpaWkoKyvD4cOHVbVCRAooCZ8//OEP2Lx5M5577jkcO3YMS5YsQUVFBXp6elS0Q0QK6FRcWFpWVoZly5bht7/9LQAgHA6joKAAGzZswDPPPPOtzw+Hw+ju7kZ2djZ0Ol282yWiKIkI/H4/8vPzoddfe9/GMEk9aYaGhtDS0oLq6mptnV6vR3l5OZxO56jPCQQCCAQC2uNz586huLg47r0S0dh0dXVh9uzZ16yZ9PC5cOECQqEQrFZrxHqr1YozZ86M+pza2lr8/Oc/v2p9V1cXzGZzXPokotj5fD4UFBQgOzv7W2snPXzGorq6Gps3b9Yej2yg2Wxm+BAloGimQyY9fGbMmIGUlBS43e6I9W63GzabbdTnmEwmmEymyWiPiCbJpB/tMhqNKC0tRWNjo7YuHA6jsbERdrt9stshIkWUfOzavHkzqqqqsHTpUixfvhy//vWv0dfXh0cffVRFO0SkgJLwefDBB/Hll1+ipqYGLpcL3/nOd/D+++9fNQlNRNcvJef5jJfP54PFYoHX6+WEM1ECieW9yWu7iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJWIOnwMHDuDuu+9Gfn4+dDod3nnnnYhxEUFNTQ1mzZqF9PR0lJeX45NPPomo6e3txZo1a2A2m5GTk4O1a9fi0qVL49oQIkouMYdPX18flixZgldeeWXU8e3bt+Pll1/Ga6+9hubmZmRmZqKiogKDg4NazZo1a9De3o6Ghgbs2bMHBw4cwOOPPz72rSCi5CPjAEDq6+u1x+FwWGw2m7zwwgvaOo/HIyaTSd566y0RETl16pQAkCNHjmg1+/btE51OJ+fOnYvqdb1erwAQr9c7nvaJaILF8t6c0Dmfs2fPwuVyoby8XFtnsVhQVlYGp9MJAHA6ncjJycHSpUu1mvLycuj1ejQ3N4/6cwOBAHw+X8RCRMltQsPH5XIBAKxWa8R6q9WqjblcLuTl5UWMGwwG5ObmajVXqq2thcVi0ZaCgoKJbJuIFEiKo13V1dXwer3a0tXVpbolIhqnCQ0fm80GAHC73RHr3W63Nmaz2dDT0xMxPjw8jN7eXq3mSiaTCWazOWIhouQ2oeFTVFQEm82GxsZGbZ3P50NzczPsdjsAwG63w+PxoKWlRavZv38/wuEwysrKJrIdIkpghlifcOnSJXz66afa47Nnz6K1tRW5ubkoLCzEpk2b8Mtf/hLz5s1DUVERtm7divz8fNx7770AgIULF+KOO+7AY489htdeew3BYBDr16/HD3/4Q+Tn50/YhhFRgov1UNqHH34oAK5aqqqqROSrw+1bt24Vq9UqJpNJVq5cKR0dHRE/4+LFi/LQQw9JVlaWmM1mefTRR8Xv90fdAw+1EyWmWN6bOhERhdk3Jj6fDxaLBV6vl/M/RAkklvdmUhztIqLrD8OHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIxhU9tbS2WLVuG7Oxs5OXl4d5770VHR0dEzeDgIBwOB6ZPn46srCysXr0abrc7oqazsxOVlZXIyMhAXl4etmzZguHh4fFvDREljZjCp6mpCQ6HA4cOHUJDQwOCwSBWrVqFvr4+reapp57Ce++9h927d6OpqQnd3d24//77tfFQKITKykoMDQ3h4MGD2LVrF3bu3ImampqJ2yoiSnwyDj09PQJAmpqaRETE4/FIamqq7N69W6s5ffq0ABCn0ykiInv37hW9Xi8ul0ur2bFjh5jNZgkEAlG9rtfrFQDi9XrH0z4RTbBY3pvjmvPxer0AgNzcXABAS0sLgsEgysvLtZoFCxagsLAQTqcTAOB0OrF48WJYrVatpqKiAj6fD+3t7aO+TiAQgM/ni1iIKLmNOXzC4TA2bdqEW2+9FYsWLQIAuFwuGI1G5OTkRNRarVa4XC6t5uvBMzI+Mjaa2tpaWCwWbSkoKBhr20SUIMYcPg6HAydPnkRdXd1E9jOq6upqeL1ebenq6or7axJRfBnG8qT169djz549OHDgAGbPnq2tt9lsGBoagsfjidj7cbvdsNlsWs3hw4cjft7I0bCRmiuZTCaYTKaxtEpECSqmPR8Rwfr161FfX4/9+/ejqKgoYry0tBSpqalobGzU1nV0dKCzsxN2ux0AYLfb0dbWhp6eHq2moaEBZrMZxcXF49kWIkoiMe35OBwOvPnmm3j33XeRnZ2tzdFYLBakp6fDYrFg7dq12Lx5M3Jzc2E2m7FhwwbY7XasWLECALBq1SoUFxfj4Ycfxvbt2+FyufDss8/C4XBw74ZoKonlMBqAUZc33nhDqxkYGJAnn3xSpk2bJhkZGXLffffJ+fPnI37O559/Lnfeeaekp6fLjBkz5Omnn5ZgMBh1HzzUTpSYYnlv6kRE1EXf2Ph8PlgsFni9XpjNZtXtENFlsbw3eW0XESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpMaar2okmgoTDCPX1ITw0BF1KClLS06EzGqHT6VS3RpOA4UOTTkQQ7O3Fl/v2wXv4MIYuXIDeZELG3LnIu+suZJeUQJeSorpNijOGD00qEUGguxufv/QS+jo6gMuXFoYuXYL34kVcamtD/iOPYGZFBQPoOsc5H5pUof5+dL7+OvrOnIGEw/hHIICjFy7gE58PYRGE+vtx7r/+C95jx5CE1zxTDLjnQ5NGROBtbob/xAmICDr7+rD1+HF0eL3INBjwnzfdhAeLioD+frjr65G9aBFS0tNVt01xwj0fmjQBlwvndu0CwmEIgP/X1oZTHg9CIvAFg/jt6dM4+Y9/AAD6P/sMoYEBtQ1TXDF8aNL4T5xA8PLXLQGALxiMGB8KhxEIhSa7LVKE4UOTQsJhDHzxBRAOAwB0AP7dZoPha4fVbzKbcWNWlqIOabJxzocmRXhwEP4TJ7THOp0OVXPnIjs1Ff99/jxmpafjsZtuQl5aGgB8dc6Pnv83Xs8YPjQphv1+BC/P54ww6PX4jzlz8MCcORjZ/xk5wdB8880w8Ba51zWGD8WdiMB/4gRCfX1Xjel0Olx1PrNOB9OsWdzzuc7xb5fiTwQBl0ub7/k2upQUWJYti3NTpBrDh+JOhofhOXQo6nrjzJlItVji2BElAoYPxV2ory+mc3bSi4pgYPhc9xg+FHd9n36KYY8n6vqshQs53zMF8G+Y4kpEEDh3DjI8HN0TdDpkLlgQ36YoITB8KL5E4Glujro87YYbYLJa49gQJQqGD8XVsN+PYG9v1PXGvDzO90wRDB+KGxHBwNmzCLjdUT+Hh9inDoYPxdVAV1fU5/cgJQXphYW8jeoUwfChuPLGMN9jstmQPmdO/JqhhMLwobgZ9noxdPFi1PXGmTORkpkZx44okTB8KG4CbjcC585FXZ+zfDnAj1xTBsOH4iZw/nzUtTqjEWmzZ3O+Zwph+FBciAi8R49GXZ86bRoy/vmf49gRJRqGD8VFqL//qyvZo2TIzIT+8o3EaGpg+FBcBC9exMBnn0VdP+1f/xU6A28vNZXEFD47duxASUkJzGYzzGYz7HY79u3bp40PDg7C4XBg+vTpyMrKwurVq+G+4gSzzs5OVFZWIiMjA3l5ediyZQuGo73uh5JG/2efQWI4v8c4c2Z8G6KEE1P4zJ49G9u2bUNLSwuOHj2K733ve7jnnnvQ3t4OAHjqqafw3nvvYffu3WhqakJ3dzfuv/9+7fmhUAiVlZUYGhrCwYMHsWvXLuzcuRM1NTUTu1WklIjgUltb1CcXGrKzkb1kCSebpxidjPNrIXNzc/HCCy/ggQcewMyZM/Hmm2/igQceAACcOXMGCxcuhNPpxIoVK7Bv3z58//vfR3d3N6yXLx587bXX8JOf/ARffvkljEbjqK8RCAQQCAS0xz6fDwUFBfB6vTDzPr8JJ9TXh7/X1KD/k0+iqs+cPx/zfvELpHDOJ+n5fD5YLJao3ptjnvMJhUKoq6tDX18f7HY7WlpaEAwGUV5ertUsWLAAhYWFcDqdAACn04nFixdrwQMAFRUV8Pl82t7TaGpra2GxWLSloKBgrG3TJAj192Owqyvq+uySEuhNpjh2RIko5vBpa2tDVlYWTCYTnnjiCdTX16O4uBgulwtGoxE5OTkR9VarFa7LRz1cLldE8IyMj4x9k+rqani9Xm3piuEfNk2+gc5OSJRf/qczGJBeVMSPXFNQzIcX5s+fj9bWVni9Xrz99tuoqqpCU1NTPHrTmEwmmPg/Y1IQEVxqb4dc8W2k30RnNCJr/vw4d0WJKObwMRqNmDt3LgCgtLQUR44cwUsvvYQHH3wQQ0ND8Hg8EXs/brcbNpsNAGCz2XD48OGInzdyNGykhpKbBIPoP3s26vqMoiKk8FtKp6Rxn+cTDocRCARQWlqK1NRUNDY2amMdHR3o7OyE3W4HANjtdrS1taGnp0eraWhogNlsRnFx8XhboQQQGhhA/6efRl2ffuONPLlwioppz6e6uhp33nknCgsL4ff78eabb+Kjjz7CBx98AIvFgrVr12Lz5s3Izc2F2WzGhg0bYLfbsWLFCgDAqlWrUFxcjIcffhjbt2+Hy+XCs88+C4fDwY9V14n+//mf6L+pIiWFh9insJjCp6enB4888gjOnz8Pi8WCkpISfPDBB7j99tsBAC+++CL0ej1Wr16NQCCAiooKvPrqq9rzU1JSsGfPHqxbtw52ux2ZmZmoqqrC888/P7FbRUqICAY+/xwyNBRVvd5gQPqNN8a5K0pU4z7PR4VYziWgyRMOBvHpL34Bf2trVPVZixZhbk0Nz++5jkzKeT5EV5JgEIHu7qjrTbNm8fyeKYzhQxNm8H//F6G+vuiKdTrkLF/O+Z4pjOFDE2Jkvifa8NGnpcE0a1acu6JExvChiSEC/8mTUZenFxbyywGnOIYPTQgJBtEfw/17DGYz53umOIYPTYiBzk4Mffll1PW5t90Wx24oGTB8aNxEBIGeHoSjPLlQZzDw5mHE8KGJ4Tl4MOratIICZPzTP8WxG0oGDB8aNxkextDXrtf7NsaZM6H7hhvH0dTB8KFxG+zsjOlK9mmXLzSmqY3hQ+MW9Pmiv57r8vk9PLmQGD40br6WlqhrU6dNQzrnewgMHxqncDAY0/2aM+bOhT41NY4dUbJg+NC4BC9cQF8MNw/LWrAA0POfHTF8aJyGenujv54rIwMZ8+ZxvocAMHxoHEQEHqcz6i8HTMnIQNrs2XHuipIFw4fGTIaHY7qkInvRIt6vmTQMHxqzYa8X/hMnoq5PLyyELiUljh1RMmH40NiJRP3lgPq0NGSXlHC+hzQMH5oUKZmZMPL+PfQ1DB+aFFn/8i9IycxU3QYlEIYPjZnBYkF2Scm31ulSUzHj9tuhN8T8Bbl0HWP40JjpjUbYHngAqdOnX6NIj+krVyJr4cLJa4ySAsOHxiXzpptQuG7dqPM5utRU5P7bv+GGH/0Iet5Cg67A/WAaF51eD8uyZTDl56N3/37429sRHhyEadYsTLv1VliWLeOXAtKoGD40bjqdDmk33ID8H/0IEPlq0ekAvZ6H1ukbMXxoQuh0uq8ChyhKnPMhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKTGu8Nm2bRt0Oh02bdqkrRscHITD4cD06dORlZWF1atXw+12Rzyvs7MTlZWVyMjIQF5eHrZs2YLh4eHxtEJESWbM4XPkyBG8/vrrKLnilgpPPfUU3nvvPezevRtNTU3o7u7G/fffr42HQiFUVlZiaGgIBw8exK5du7Bz507U1NSMfSuIKPnIGPj9fpk3b540NDTIbbfdJhs3bhQREY/HI6mpqbJ7926t9vTp0wJAnE6niIjs3btX9Hq9uFwurWbHjh1iNpslEAiM+nqDg4Pi9Xq1paurSwCI1+sdS/tEFCderzfq9+aY9nwcDgcqKytRXl4esb6lpQXBYDBi/YIFC1BYWAin0wkAcDqdWLx4MaxfuwVDRUUFfD4f2tvbR3292tpaWCwWbSkoKBhL20SUQGIOn7q6Ohw7dgy1tbVXjblcLhiNRuTk5ESst1qtcLlcWo31inu/jDweqblSdXU1vF6vtnTF8PW8RJSYYrqqvaurCxs3bkRDQwPSJvEeLSaTCSaTadJej4jiL6Y9n5aWFvT09OCWW26BwWCAwWBAU1MTXn75ZRgMBlitVgwNDcHj8UQ8z+12w2azAQBsNttVR79GHo/UENH1L6bwWblyJdra2tDa2qotS5cuxZo1a7Q/p6amorGxUXtOR0cHOjs7YbfbAQB2ux1tbW3o6enRahoaGmA2m1FcXDxBm0VEiS6mj13Z2dlYtGhRxLrMzExMnz5dW7927Vps3rwZubm5MJvN2LBhA+x2O1asWAEAWLVqFYqLi/Hwww9j+/btcLlcePbZZ+FwOPjRimgKmfA7Gb744ovQ6/VYvXo1AoEAKioq8Oqrr2rjKSkp2LNnD9atWwe73Y7MzExUVVXh+eefn+hWiCiB6UREVDcRK5/PB4vFAq/XC7PZrLodIroslvcmr+0iIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSwqC6gbEQEQCAz+dT3AkRfd3Ie3LkPXotSRk+Fy9eBAAUFBQo7oSIRuP3+2GxWK5Zk5Thk5ubCwDo7Oz81g1MND6fDwUFBejq6oLZbFbdTtTY9+RK1r5FBH6/H/n5+d9am5Tho9d/NVVlsViS6i/m68xmc1L2zr4nVzL2He0OASeciUgJhg8RKZGU4WMymfDcc8/BZDKpbiVmydo7+55cydp3LHQSzTExIqIJlpR7PkSU/Bg+RKQEw4eIlGD4EJESDB8iUiIpw+eVV17BnDlzkJaWhrKyMhw+fFhpPwcOHMDdd9+N/Px86HQ6vPPOOxHjIoKamhrMmjUL6enpKC8vxyeffBJR09vbizVr1sBsNiMnJwdr167FpUuX4tp3bW0tli1bhuzsbOTl5eHee+9FR0dHRM3g4CAcDgemT5+OrKwsrF69Gm63O6Kms7MTlZWVyMjIQF5eHrZs2YLh4eG49b1jxw6UlJRoZ//a7Xbs27cvoXsezbZt26DT6bBp06ak631CSJKpq6sTo9Eov//976W9vV0ee+wxycnJEbfbraynvXv3ys9+9jP54x//KACkvr4+Ynzbtm1isVjknXfekb/97W/ygx/8QIqKimRgYECrueOOO2TJkiVy6NAh+etf/ypz586Vhx56KK59V1RUyBtvvCEnT56U1tZWueuuu6SwsFAuXbqk1TzxxBNSUFAgjY2NcvToUVmxYoV897vf1caHh4dl0aJFUl5eLsePH5e9e/fKjBkzpLq6Om59/+lPf5I///nP8ve//106Ojrkpz/9qaSmpsrJkycTtucrHT58WObMmSMlJSWyceNGbX0y9D5Rki58li9fLg6HQ3scCoUkPz9famtrFXb1f64Mn3A4LDabTV544QVtncfjEZPJJG+99ZaIiJw6dUoAyJEjR7Saffv2iU6nk3Pnzk1a7z09PQJAmpqatD5TU1Nl9+7dWs3p06cFgDidThH5Knj1er24XC6tZseOHWI2myUQCExa79OmTZPf/e53SdGz3++XefPmSUNDg9x2221a+CRD7xMpqT52DQ0NoaWlBeXl5do6vV6P8vJyOJ1OhZ19s7Nnz8LlckX0bLFYUFZWpvXsdDqRk5ODpUuXajXl5eXQ6/Vobm6etF69Xi+A/7trQEtLC4LBYETvCxYsQGFhYUTvixcvhtVq1WoqKirg8/nQ3t4e955DoRDq6urQ19cHu92eFD07HA5UVlZG9Agkx+97IiXVVe0XLlxAKBSK+MUDgNVqxZkzZxR1dW0ulwsARu15ZMzlciEvLy9i3GAwIDc3V6uJt3A4jE2bNuHWW2/FokWLtL6MRiNycnKu2fto2zYyFi9tbW2w2+0YHBxEVlYW6uvrUVxcjNbW1oTtGQDq6upw7NgxHDly5KqxRP59x0NShQ/Fj8PhwMmTJ/Hxxx+rbiUq8+fPR2trK7xeL95++21UVVWhqalJdVvX1NXVhY0bN6KhoQFpaWmq21EuqT52zZgxAykpKVfN/rvdbthsNkVdXdtIX9fq2WazoaenJ2J8eHgYvb29k7Jd69evx549e/Dhhx9i9uzZ2nqbzYahoSF4PJ5r9j7ato2MxYvRaMTcuXNRWlqK2tpaLFmyBC+99FJC99zS0oKenh7ccsstMBgMMBgMaGpqwssvvwyDwQCr1ZqwvcdDUoWP0WhEaWkpGhsbtXXhcBiNjY2w2+0KO/tmRUVFsNlsET37fD40NzdrPdvtdng8HrS0tGg1+/fvRzgcRllZWdx6ExGsX78e9fX12L9/P4qKiiLGS0tLkZqaGtF7R0cHOjs7I3pva2uLCM+GhgaYzWYUFxfHrfcrhcNhBAKBhO555cqVaGtrQ2trq7YsXboUa9as0f6cqL3HheoZ71jV1dWJyWSSnTt3yqlTp+Txxx+XnJyciNn/yeb3++X48eNy/PhxASC/+tWv5Pjx4/LFF1+IyFeH2nNycuTdd9+VEydOyD333DPqofabb75Zmpub5eOPP5Z58+bF/VD7unXrxGKxyEcffSTnz5/Xlv7+fq3miSeekMLCQtm/f78cPXpU7Ha72O12bXzk0O+qVauktbVV3n//fZk5c2ZcD/0+88wz0tTUJGfPnpUTJ07IM888IzqdTv7yl78kbM/f5OtHu5Kt9/FKuvAREfnNb34jhYWFYjQaZfny5XLo0CGl/Xz44YcC4KqlqqpKRL463L5161axWq1iMplk5cqV0tHREfEzLl68KA899JBkZWWJ2WyWRx99VPx+f1z7Hq1nAPLGG29oNQMDA/Lkk0/KtGnTJCMjQ+677z45f/58xM/5/PPP5c4775T09HSZMWOGPP300xIMBuPW949//GO58cYbxWg0ysyZM2XlypVa8CRqz9/kyvBJpt7Hi/fzISIlkmrOh4iuHwwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREr8f6Ju43VKni04AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(\n",
    "            [action * 2])\n",
    "        over = terminated or truncated\n",
    "\n",
    "        #偏移reward,便于训练\n",
    "        reward = (reward + 8) / 8\n",
    "\n",
    "        #限制最大步数\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            over = True\n",
    "\n",
    "        return state, reward, over\n",
    "\n",
    "    #打印游戏图像\n",
    "    def show(self):\n",
    "        from matplotlib import pyplot as plt\n",
    "        plt.figure(figsize=(3, 3))\n",
    "        plt.imshow(self.env.render())\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()\n",
    "\n",
    "env.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((tensor([[0.0786],\n",
       "          [0.1098]], grad_fn=<TanhBackward0>),\n",
       "  tensor([[1.1475],\n",
       "          [0.9288]], grad_fn=<ExpBackward0>)),\n",
       " tensor([[-0.0460],\n",
       "         [-0.0953]], grad_fn=<AddmmBackward0>))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "#定义模型\n",
    "class Model(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.s = torch.nn.Sequential(\n",
    "            torch.nn.Linear(3, 64),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(64, 64),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "        self.mu = torch.nn.Sequential(\n",
    "            torch.nn.Linear(64, 1),\n",
    "            torch.nn.Tanh(),\n",
    "        )\n",
    "        self.sigma = torch.nn.Sequential(\n",
    "            torch.nn.Linear(64, 1),\n",
    "            torch.nn.Tanh(),\n",
    "        )\n",
    "\n",
    "    def forward(self, state):\n",
    "        state = self.s(state)\n",
    "\n",
    "        return self.mu(state), self.sigma(state).exp()\n",
    "\n",
    "\n",
    "model_action = Model()\n",
    "\n",
    "model_value = torch.nn.Sequential(\n",
    "    torch.nn.Linear(3, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 1),\n",
    ")\n",
    "\n",
    "model_action(torch.randn(2, 3)), model_value(torch.randn(2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_4775/3994207745.py:34: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  state = torch.FloatTensor(state).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "17.590099334716797"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "\n",
    "\n",
    "#玩一局游戏并记录数据\n",
    "def play(show=False):\n",
    "    state = []\n",
    "    action = []\n",
    "    reward = []\n",
    "    next_state = []\n",
    "    over = []\n",
    "\n",
    "    s = env.reset()\n",
    "    o = False\n",
    "    while not o:\n",
    "        #根据概率采样\n",
    "        mu, sigma = model_action(torch.FloatTensor(s).reshape(1, 3))\n",
    "        a = random.normalvariate(mu=mu.item(), sigma=sigma.item())\n",
    "\n",
    "        ns, r, o = env.step(a)\n",
    "\n",
    "        state.append(s)\n",
    "        action.append(a)\n",
    "        reward.append(r)\n",
    "        next_state.append(ns)\n",
    "        over.append(o)\n",
    "\n",
    "        s = ns\n",
    "\n",
    "        if show:\n",
    "            display.clear_output(wait=True)\n",
    "            env.show()\n",
    "\n",
    "    state = torch.FloatTensor(state).reshape(-1, 3)\n",
    "    action = torch.FloatTensor(action).reshape(-1, 1)\n",
    "    reward = torch.FloatTensor(reward).reshape(-1, 1)\n",
    "    next_state = torch.FloatTensor(next_state).reshape(-1, 3)\n",
    "    over = torch.LongTensor(over).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over, reward.sum().item()\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over, reward_sum = play()\n",
    "\n",
    "reward_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer_action = torch.optim.Adam(model_action.parameters(), lr=5e-4)\n",
    "optimizer_value = torch.optim.Adam(model_value.parameters(), lr=5e-3)\n",
    "\n",
    "\n",
    "def requires_grad(model, value):\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad_(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([200, 1])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_value(state, reward, next_state, over):\n",
    "    requires_grad(model_action, False)\n",
    "    requires_grad(model_value, True)\n",
    "\n",
    "    #计算target\n",
    "    with torch.no_grad():\n",
    "        target = model_value(next_state)\n",
    "    target = target * 0.98 * (1 - over) + reward\n",
    "\n",
    "    #每批数据反复训练10次\n",
    "    for _ in range(10):\n",
    "        #计算value\n",
    "        value = model_value(state)\n",
    "\n",
    "        loss = torch.nn.functional.mse_loss(value, target)\n",
    "        loss.backward()\n",
    "        optimizer_value.step()\n",
    "        optimizer_value.zero_grad()\n",
    "\n",
    "    #减去value相当于去基线\n",
    "    return (target - value).detach()\n",
    "\n",
    "\n",
    "value = train_value(state, reward, next_state, over)\n",
    "\n",
    "value.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.07272527366876602"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_action(state, action, value):\n",
    "    requires_grad(model_action, True)\n",
    "    requires_grad(model_value, False)\n",
    "\n",
    "    #计算当前state的价值,其实就是Q(state,action),这里是用蒙特卡洛法估计的\n",
    "    delta = []\n",
    "    for i in range(len(value)):\n",
    "        s = 0\n",
    "        for j in range(i, len(value)):\n",
    "            s += value[j] * (0.9 * 0.9)**(j - i)\n",
    "        delta.append(s)\n",
    "    delta = torch.FloatTensor(delta).reshape(-1, 1)\n",
    "\n",
    "    #更新前的动作概率\n",
    "    with torch.no_grad():\n",
    "        mu, sigma = model_action(state)\n",
    "        prob_old = torch.distributions.Normal(mu, sigma).log_prob(action).exp()\n",
    "\n",
    "    #每批数据反复训练10次\n",
    "    for _ in range(10):\n",
    "        #更新后的动作概率\n",
    "        mu, sigma = model_action(state)\n",
    "        prob_new = torch.distributions.Normal(mu, sigma).log_prob(action).exp()\n",
    "\n",
    "        #求出概率的变化\n",
    "        ratio = prob_new / prob_old\n",
    "\n",
    "        #计算截断的和不截断的两份loss,取其中小的\n",
    "        surr1 = ratio * delta\n",
    "        surr2 = ratio.clamp(0.8, 1.2) * delta\n",
    "\n",
    "        loss = -torch.min(surr1, surr2).mean()\n",
    "\n",
    "        #更新参数\n",
    "        loss.backward()\n",
    "        optimizer_action.step()\n",
    "        optimizer_action.zero_grad()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "train_action(state, action, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.10162344574928284 51.45712678432464\n",
      "100 -0.7537704706192017 107.65102710723878\n",
      "200 0.5636732578277588 85.48153505325317\n",
      "300 0.16421520709991455 95.98738517761231\n",
      "400 0.19185616075992584 117.04133224487305\n",
      "500 -0.3527289927005768 139.04767227172852\n",
      "600 -0.32614269852638245 163.88694686889647\n",
      "700 0.14939337968826294 50.68153982162475\n",
      "800 -0.23504677414894104 125.99543991088868\n",
      "900 0.9372154474258423 140.94840087890626\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    model_action.train()\n",
    "    model_value.train()\n",
    "\n",
    "    #训练N局\n",
    "    for epoch in range(1000):\n",
    "        #一个epoch最少玩N步\n",
    "        steps = 0\n",
    "        while steps < 200:\n",
    "            state, action, reward, next_state, over, _ = play()\n",
    "            steps += len(state)\n",
    "\n",
    "            #训练两个模型\n",
    "            delta = train_value(state, reward, next_state, over)\n",
    "            loss = train_action(state, action, delta)\n",
    "\n",
    "        if epoch % 100 == 0:\n",
    "            test_result = sum([play()[-1] for _ in range(20)]) / 20\n",
    "            print(epoch, loss, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbQElEQVR4nO3df2xT973/8Zed2M7P4xAgdlOSEl1YIV9+rA0QvH6/7TQ80i7rysiuOoRY1nHbCzWIHxMa2VqqVdMNotJY2fgxfasW/mlTpXehI4N2UaBhvXX5kZI2BEg7iS65gB1+LHYIjePE7/sHzbkYUhonsT928npIlppzPknepvjJiU/sYxARARFRjBlVD0BE4xPjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKKIvPzp07MXXqVKSkpKC4uBjHjx9XNQoRKaAkPm+++SY2btyIF154AR999BHmzp2LkpISdHR0qBiHiBQwqHhhaXFxMebPn48//OEPAIBQKIS8vDysXbsWmzdv/trPD4VCuHjxIjIzM2EwGKI9LhENkYigq6sLubm5MBrvfmyTHKOZdL29vWhsbERFRYW+zWg0wul0wu12D/o5gUAAgUBA//jChQsoLCyM+qxENDzt7e2YMmXKXdfEPD5XrlxBf38/bDZb2HabzYZz584N+jmVlZX49a9/fcf29vZ2aJoWlTmJKHJ+vx95eXnIzMz82rUxj89wVFRUYOPGjfrHA3dQ0zTGhygODeXpkJjHZ9KkSUhKSoLX6w3b7vV6YbfbB/0ci8UCi8USi/GIKEZifrbLbDajqKgI9fX1+rZQKIT6+no4HI5Yj0NEiij5sWvjxo0oLy/HvHnzsGDBAvzud79Dd3c3nnrqKRXjEJECSuLz5JNP4vLly9iyZQs8Hg+++c1v4p133rnjSWgiGruU/J7PSPn9flitVvh8Pj7hTBRHInls8rVdRKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKQE40NESjA+RKRExPE5evQoHn/8ceTm5sJgMGD//v1h+0UEW7ZswT333IPU1FQ4nU589tlnYWuuXbuG5cuXQ9M0ZGVlYeXKlbh+/fqI7ggRJZaI49Pd3Y25c+di586dg+7ftm0bduzYgT179uDYsWNIT09HSUkJenp69DXLly9HS0sL6urqUFtbi6NHj+KZZ54Z/r0gosQjIwBAampq9I9DoZDY7XZ56aWX9G2dnZ1isVjkjTfeEBGRM2fOCAA5ceKEvubQoUNiMBjkwoULQ/q+Pp9PAIjP5xvJ+EQ0yiJ5bI7qcz7nz5+Hx+OB0+nUt1mtVhQXF8PtdgMA3G43srKyMG/ePH2N0+mE0WjEsWPHBv26gUAAfr8/7EZEiW1U4+PxeAAANpstbLvNZtP3eTwe5OTkhO1PTk5Gdna2vuZ2lZWVsFqt+i0vL280xyYiBRLibFdFRQV8Pp9+a29vVz0SEY3QqMbHbrcDALxeb9h2r9er77Pb7ejo6Ajb39fXh2vXrulrbmexWKBpWtiNiBLbqManoKAAdrsd9fX1+ja/349jx47B4XAAABwOBzo7O9HY2KivOXz4MEKhEIqLi0dzHCKKY8mRfsL169fx97//Xf/4/PnzaGpqQnZ2NvLz87F+/Xr85je/wfTp01FQUIDnn38eubm5WLJkCQBg5syZePTRR/H0009jz549CAaDWLNmDX784x8jNzd31O4YEcW5SE+lHTlyRADccSsvLxeRm6fbn3/+ebHZbGKxWGTRokXS2toa9jWuXr0qy5Ytk4yMDNE0TZ566inp6uoa8gw81U4UnyJ5bBpERBS2b1j8fj+sVit8Ph+f/yGKI5E8NhPibBcRjT2MDxEpwfgQkRIRn+0iGgkRQX93NwKXLiHU04Ok9HRYcnNhtFhgMBhUj0cxxPhQzISCQfzzv/4LHW+/jZ4LFxAKBJCUmoq0f/kX2P/1X5E5Zw4MRh6MjxeMD8WE9PWho7YWl954A6Fb3l6l/8YNdDU344u2NuT9+79jwkMP8QhonOA/MxR1IgL/xx/D8+abYeG5VZ/Ph/9+9VX08HV74wbjQ1En/f3oqK1F/40bd10XvHIFl995Bwn4q2c0DIwPRV8ohMCFC0NaeuOzzyC9vVEeiOIB40NxpfvTT3Hj889Vj0ExwPhQ3AkFAqpHoBhgfCju8Meu8YHxobgTYnzGBcaH4osI4zNOMD4Udxif8YHxoegzGGBMSxvy8r7OzujNQnGD8aGoMxiNSCsoGPL6L3iqfVxgfCj6DAYYzGbVU1CcYXwoJoyMD92G8aGYYHzodowPRZ/BwPjQHRgfigmDyRTRer6yfexjfCgmDMlDf986CYWiOAnFC8aHoi7SdyaUvj6AARrzGB+KO9LXx6OfcYDxobjDI5/xgfGhuBPikc+4wPhQ3JFgkPEZBxgfiomk9HRgiNfk6r9+nW8oNg4wPhQTKV9elXQoei9f/torXVDiY3woJgwmE8CLAdItGB+KCaPJxCuRUhjGh2KCRz50O8aHYsJoMsEwxCecaXzg3waKCR750O0iik9lZSXmz5+PzMxM5OTkYMmSJWhtbQ1b09PTA5fLhYkTJyIjIwNlZWXwer1ha9ra2lBaWoq0tDTk5ORg06ZN6OvrG/m9obhlSE6OKD78PZ+xL6L4NDQ0wOVy4cMPP0RdXR2CwSAWL16M7u5ufc2GDRtw4MABVFdXo6GhARcvXsTSpUv1/f39/SgtLUVvby8++OAD7Nu3D3v37sWWLVtG715RwuPv+Yx9BhnBG6dcvnwZOTk5aGhowMMPPwyfz4fJkyfj9ddfx49+9CMAwLlz5zBz5ky43W4sXLgQhw4dwve//31cvHgRNpsNALBnzx784he/wOXLl2EewptO+f1+WK1W+Hw+aJo23PEphvquX0fLs88O6coUhqQkfOM//gMZM2dGfzAaVZE8Nkf0nI/P5wMAZGdnAwAaGxsRDAbhdDr1NTNmzEB+fj7cbjcAwO12Y/bs2Xp4AKCkpAR+vx8tLS2Dfp9AIAC/3x92o7GN1+4a+4Ydn1AohPXr1+Ohhx7CrFmzAAAejwdmsxlZWVlha202Gzwej77m1vAM7B/YN5jKykpYrVb9lpeXN9yxKQEIGJ/xYNjxcblcOH36NKqqqkZznkFVVFTA5/Ppt/b29qh/T1JIBBIMqp6Comzo7215izVr1qC2thZHjx7FlClT9O12ux29vb3o7OwMO/rxer2w2+36muPHj4d9vYGzYQNrbmexWGAZ4uuCKH4N+TecRdB75Up0hyHlIjryERGsWbMGNTU1OHz4MApuuwplUVERTCYT6uvr9W2tra1oa2uDw+EAADgcDjQ3N6Ojo0NfU1dXB03TUFhYOJL7QnHMaDIhderUoS0WQeDixajOQ+pFdOTjcrnw+uuv4+2330ZmZqb+HI3VakVqaiqsVitWrlyJjRs3Ijs7G5qmYe3atXA4HFi4cCEAYPHixSgsLMSKFSuwbds2eDwePPfcc3C5XDy6GcuMxptvq0H0pYjis3v3bgDAt7/97bDtr732Gn76058CALZv3w6j0YiysjIEAgGUlJRg165d+tqkpCTU1tZi9erVcDgcSE9PR3l5OV588cWR3ROKbwZDxJfPobEtovgM5VeCUlJSsHPnTuzcufMr19x33304ePBgJN+aEpzBYICR8aFb8LVdFBs88qHbMD4UG8O4ZDKvWjq2MT4UM7xqKd2K8aGYMBgMkb2qva8P4JHPmMb4UFwKBYO8cOAYx/hQXJJgkM/5jHGMD8UlCQb5Y9cYx/hQXArxqqVjHuNDMWNIShry2uDVq3xl+xjH+FDMpBYUDPl0e+/VqzefdKYxi/GhmElKSeEVLEjH+FDMGMzmsPiERHAtEEAXz2yNS8N6MzGi4Ri4ZLIA6AuF8He/Hz39/bgaCCAvPR3TNA1GHhmNG4wPxYzxyyMfEcF/d3fjRn8/fnfmDFp9PqQnJ+PfvvENPFlQwMPxcYLxoZi59VXt/mAQ///TT3Hmy0vp+INB/OHsWaQmJaHk3nuRoWhGih3+I0MxY7RYYExJAXDztV5dt53N6g2F8Knfj17+fs+4wPhQzJgmTED2ww/ffGMxAP83JwfJtzzHMy0zE/MmToTVZEKSxQKDkX89xzL+2EUxYzAaMcnpxNXDh5EbCOD/2WzQzGa85/EgJyUF383NxUyrFQaDARMeeQTJvBrtmMb4UEyl5OfjnmXLENq3D4KbFwj8P1lZMAC4JzUVk1JTkVFYCNuSJTzyGeMYH4opg9GIyY8+CgBI/s//RPY//6m/gNRgMkF74AFMWbkS5i8vwU1jF+NDMWc0mZBTWgrrgw/Cf+oUAh4PjKmpyJg5ExkzZyIpNVX1iBQDjA8pYTAakXLvvUi5917Vo5Ai/KGaiJRgfIhICcaHiJRgfIhICcaHiJRgfIhICcaHiJRgfIhICcaHiJRgfIhICcaHiJRgfIhICcaHiJRgfIhIiYjis3v3bsyZMweapkHTNDgcDhw6dEjf39PTA5fLhYkTJyIjIwNlZWXwer1hX6OtrQ2lpaVIS0tDTk4ONm3ahL6+vtG5N0SUMCKKz5QpU7B161Y0Njbi5MmT+M53voMnnngCLS0tAIANGzbgwIEDqK6uRkNDAy5evIilS5fqn9/f34/S0lL09vbigw8+wL59+7B3715s2bJldO8VEcU/GaEJEybIK6+8Ip2dnWIymaS6ulrfd/bsWQEgbrdbREQOHjwoRqNRPB6Pvmb37t2iaZoEAoGv/B49PT3i8/n0W3t7uwAQn8830vGJaBT5fL4hPzaH/ZxPf38/qqqq0N3dDYfDgcbGRgSDQTidTn3NjBkzkJ+fD7fbDQBwu92YPXs2bDabvqakpAR+v18/ehpMZWUlrFarfsvLyxvu2EQUJyKOT3NzMzIyMmCxWLBq1SrU1NSgsLAQHo8HZrMZWVlZYettNhs8Hg8AwOPxhIVnYP/Avq9SUVEBn8+n39rb2yMdm4jiTMTv4Xz//fejqakJPp8Pb731FsrLy9HQ0BCN2XQWiwUWiyWq34OIYivi+JjNZkybNg0AUFRUhBMnTuDll1/Gk08+id7eXnR2doYd/Xi9XtjtdgCA3W7H8ePHw77ewNmwgTVEND6M+Pd8QqEQAoEAioqKYDKZUF9fr+9rbW1FW1sbHA4HAMDhcKC5uRkdHR36mrq6OmiahsLCwpGOQkQJJKIjn4qKCjz22GPIz89HV1cXXn/9dbz33nt49913YbVasXLlSmzcuBHZ2dnQNA1r166Fw+HAwoULAQCLFy9GYWEhVqxYgW3btsHj8eC5556Dy+Xij1VE40xE8eno6MBPfvITXLp0CVarFXPmzMG7776L7373uwCA7du3w2g0oqysDIFAACUlJdi1a5f++UlJSaitrcXq1avhcDiQnp6O8vJyvPjii6N7r4go7hlEvrxWbQLx+/2wWq3w+XzQNE31OET0pUgem3xtFxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpMaL4bN26FQaDAevXr9e39fT0wOVyYeLEicjIyEBZWRm8Xm/Y57W1taG0tBRpaWnIycnBpk2b0NfXN5JRiCjBDDs+J06cwB//+EfMmTMnbPuGDRtw4MABVFdXo6GhARcvXsTSpUv1/f39/SgtLUVvby8++OAD7Nu3D3v37sWWLVuGfy+IKPHIMHR1dcn06dOlrq5OHnnkEVm3bp2IiHR2dorJZJLq6mp97dmzZwWAuN1uERE5ePCgGI1G8Xg8+prdu3eLpmkSCAQG/X49PT3i8/n0W3t7uwAQn883nPGJKEp8Pt+QH5vDOvJxuVwoLS2F0+kM297Y2IhgMBi2fcaMGcjPz4fb7QYAuN1uzJ49GzabTV9TUlICv9+PlpaWQb9fZWUlrFarfsvLyxvO2EQURyKOT1VVFT766CNUVlbesc/j8cBsNiMrKytsu81mg8fj0dfcGp6B/QP7BlNRUQGfz6ff2tvbIx2biOJMciSL29vbsW7dOtTV1SElJSVaM93BYrHAYrHE7PsRUfRFdOTT2NiIjo4OPPjgg0hOTkZycjIaGhqwY8cOJCcnw2azobe3F52dnWGf5/V6YbfbAQB2u/2Os18DHw+sIaKxL6L4LFq0CM3NzWhqatJv8+bNw/Lly/X/NplMqK+v1z+ntbUVbW1tcDgcAACHw4Hm5mZ0dHToa+rq6qBpGgoLC0fpbhFRvIvox67MzEzMmjUrbFt6ejomTpyob1+5ciU2btyI7OxsaJqGtWvXwuFwYOHChQCAxYsXo7CwECtWrMC2bdvg8Xjw3HPPweVy8UcronEkovgMxfbt22E0GlFWVoZAIICSkhLs2rVL35+UlITa2lqsXr0aDocD6enpKC8vx4svvjjaoxBRHDOIiKgeIlJ+vx9WqxU+nw+apqkeh4i+FMljk6/tIiIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlklUPMBwiAgDw+/2KJyGiWw08Jgceo3eTkPG5evUqACAvL0/xJEQ0mK6uLlit1ruuScj4ZGdnAwDa2tq+9g7GG7/fj7y8PLS3t0PTNNXjDBnnjq1EnVtE0NXVhdzc3K9dm5DxMRpvPlVltVoT6n/MrTRNS8jZOXdsJeLcQz0g4BPORKQE40NESiRkfCwWC1544QVYLBbVo0QsUWfn3LGVqHNHwiBDOSdGRDTKEvLIh4gSH+NDREowPkSkBONDREowPkSkRELGZ+fOnZg6dSpSUlJQXFyM48ePK53n6NGjePzxx5GbmwuDwYD9+/eH7RcRbNmyBffccw9SU1PhdDrx2Wefha25du0ali9fDk3TkJWVhZUrV+L69etRnbuyshLz589HZmYmcnJysGTJErS2toat6enpgcvlwsSJE5GRkYGysjJ4vd6wNW1tbSgtLUVaWhpycnKwadMm9PX1RW3u3bt3Y86cOfpv/zocDhw6dCiuZx7M1q1bYTAYsH79+oSbfVRIgqmqqhKz2SyvvvqqtLS0yNNPPy1ZWVni9XqVzXTw4EH51a9+JX/6058EgNTU1ITt37p1q1itVtm/f798/PHH8oMf/EAKCgrkiy++0Nc8+uijMnfuXPnwww/lb3/7m0ybNk2WLVsW1blLSkrktddek9OnT0tTU5N873vfk/z8fLl+/bq+ZtWqVZKXlyf19fVy8uRJWbhwoXzrW9/S9/f19cmsWbPE6XTKqVOn5ODBgzJp0iSpqKiI2tx//vOf5S9/+Yt8+umn0traKr/85S/FZDLJ6dOn43bm2x0/flymTp0qc+bMkXXr1unbE2H20ZJw8VmwYIG4XC794/7+fsnNzZXKykqFU/2v2+MTCoXEbrfLSy+9pG/r7OwUi8Uib7zxhoiInDlzRgDIiRMn9DWHDh0Sg8EgFy5ciNnsHR0dAkAaGhr0OU0mk1RXV+trzp49KwDE7XaLyM3wGo1G8Xg8+prdu3eLpmkSCARiNvuECRPklVdeSYiZu7q6ZPr06VJXVyePPPKIHp9EmH00JdSPXb29vWhsbITT6dS3GY1GOJ1OuN1uhZN9tfPnz8Pj8YTNbLVaUVxcrM/sdruRlZWFefPm6WucTieMRiOOHTsWs1l9Ph+A/33XgMbGRgSDwbDZZ8yYgfz8/LDZZ8+eDZvNpq8pKSmB3+9HS0tL1Gfu7+9HVVUVuru74XA4EmJml8uF0tLSsBmBxPjzHk0J9ar2K1euoL+/P+wPHgBsNhvOnTunaKq783g8ADDozAP7PB4PcnJywvYnJycjOztbXxNtoVAI69evx0MPPYRZs2bpc5nNZmRlZd119sHu28C+aGlubobD4UBPTw8yMjJQU1ODwsJCNDU1xe3MAFBVVYWPPvoIJ06cuGNfPP95R0NCxYeix+Vy4fTp03j//fdVjzIk999/P5qamuDz+fDWW2+hvLwcDQ0Nqse6q/b2dqxbtw51dXVISUlRPY5yCfVj16RJk5CUlHTHs/9erxd2u13RVHc3MNfdZrbb7ejo6Ajb39fXh2vXrsXkfq1Zswa1tbU4cuQIpkyZom+32+3o7e1FZ2fnXWcf7L4N7IsWs9mMadOmoaioCJWVlZg7dy5efvnluJ65sbERHR0dePDBB5GcnIzk5GQ0NDRgx44dSE5Ohs1mi9vZoyGh4mM2m1FUVIT6+np9WygUQn19PRwOh8LJvlpBQQHsdnvYzH6/H8eOHdNndjgc6OzsRGNjo77m8OHDCIVCKC4ujtpsIoI1a9agpqYGhw8fRkFBQdj+oqIimEymsNlbW1vR1tYWNntzc3NYPOvq6qBpGgoLC6M2++1CoRACgUBcz7xo0SI0NzejqalJv82bNw/Lly/X/zteZ48K1c94R6qqqkosFovs3btXzpw5I88884xkZWWFPfsfa11dXXLq1Ck5deqUAJDf/va3curUKfnHP/4hIjdPtWdlZcnbb78tn3zyiTzxxBODnmp/4IEH5NixY/L+++/L9OnTo36qffXq1WK1WuW9996TS5cu6bcbN27oa1atWiX5+fly+PBhOXnypDgcDnE4HPr+gVO/ixcvlqamJnnnnXdk8uTJUT31u3nzZmloaJDz58/LJ598Ips3bxaDwSB//etf43bmr3Lr2a5Em32kEi4+IiK///3vJT8/X8xmsyxYsEA+/PBDpfMcOXJEANxxKy8vF5Gbp9uff/55sdlsYrFYZNGiRdLa2hr2Na5evSrLli2TjIwM0TRNnnrqKenq6orq3IPNDEBee+01fc0XX3whzz77rEyYMEHS0tLkhz/8oVy6dCns63z++efy2GOPSWpqqkyaNEl+/vOfSzAYjNrcP/vZz+S+++4Ts9kskydPlkWLFunhideZv8rt8Umk2UeK7+dDREok1HM+RDR2MD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RK/A/QZBQcBR3tKwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "185.1729736328125"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "play(True)[-1]"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第9章-策略梯度算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
